package com.itcoon.gateway.core.cors;

import com.alibaba.nacos.common.utils.StringUtils;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher;
import org.springdoc.core.Constants;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.IOException;
import java.io.StringWriter;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


/**
 * Global gateway filter that post-processes the response of a routed service's springdoc
 * {@code /v3/api-docs} endpoint and inserts the gateway's own address as the first entry of the
 * OpenAPI {@code servers} list, so the documented API can be called directly through the gateway.
 *
 * <p>Requests whose path does not end with the springdoc api-docs path, and responses with a
 * known non-200 status, are passed through untouched.
 */
@Slf4j
@Component
public class DocServiceUrlModifyGatewayFilter implements GlobalFilter, Ordered {

    /**
     * Decorates the response of api-docs requests so that the JSON body is rewritten before it is
     * written to the client.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return completion signal of the (possibly decorated) chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        // Use the springdoc constant directly; its default value is /v3/api-docs.
        String apiPath = Constants.DEFAULT_API_DOCS_URL;
        ServerHttpRequest request = exchange.getRequest();
        ServerHttpResponse response = exchange.getResponse();
        URI uri = request.getURI();

        // Only springdoc api-docs requests are intercepted.
        String path = uri.getPath();
        if (path == null || !path.endsWith(apiPath)) {
            return chain.filter(exchange);
        }

        // Skip when a non-OK status has already been set on the response.
        // The status may legitimately be unset (null) at this point, which must not fail.
        if (isKnownNonOk(response)) {
            return chain.filter(exchange);
        }

        String gatewayUrl = gatewayBaseUrl(uri);
        DataBufferFactory bufferFactory = response.bufferFactory();

        ServerHttpResponseDecorator decoratedResponse = new ServerHttpResponseDecorator(response) {
            @Override
            public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
                // Leave non-Flux bodies alone, and do not touch error responses: by the time the
                // body is written the downstream status is expected to be set on the response.
                if (!(body instanceof Flux) || isKnownNonOk(response)) {
                    return super.writeWith(body);
                }

                // Collect the whole springdoc OpenAPI document, then rewrite it in one go.
                Flux<? extends DataBuffer> fluxBody = (Flux<? extends DataBuffer>) body;
                return super.writeWith(fluxBody.buffer().map(dataBuffers -> {
                    byte[] content = readAndRelease(dataBuffers);
                    byte[] rewritten = addGatewayServer(content, gatewayUrl);
                    // Same instance means "unchanged": keep the original headers as they are.
                    if (rewritten != content) {
                        response.getHeaders().setContentLength(rewritten.length);
                    }
                    return bufferFactory.wrap(rewritten);
                }));
            }
        };
        // Replace the response with the decorator for the rest of the chain.
        return chain.filter(exchange.mutate().response(decoratedResponse).build());
    }

    /**
     * Filter order.
     *
     * <p>NOTE(review): -2 presumably places this filter ahead of the gateway's response-writing
     * filter so the decorated response is in place when the body is written — confirm against the
     * Spring Cloud Gateway version in use.
     */
    @Override
    public int getOrder() {
        return -2;
    }

    /**
     * Tells whether the response carries a status that is set and is not 200.
     * An unset (null) status is treated as "not known to be an error".
     */
    private static boolean isKnownNonOk(ServerHttpResponse response) {
        return response.getStatusCode() != null
                && response.getStatusCode().value() != HttpStatus.OK.value();
    }

    /**
     * Builds the part of the request URI that precedes the path ({@code scheme://authority}).
     *
     * <p>Built from the URI components rather than by searching the decoded path inside the raw
     * URI string, which fails for percent-encoded paths and for query strings repeating the path.
     */
    private static String gatewayBaseUrl(URI uri) {
        StringBuilder base = new StringBuilder();
        if (uri.getScheme() != null) {
            base.append(uri.getScheme()).append(':');
        }
        if (uri.getRawAuthority() != null) {
            base.append("//").append(uri.getRawAuthority());
        }
        return base.toString();
    }

    /**
     * Joins the buffered chunks into one byte array and releases the joined buffer,
     * even when reading fails.
     */
    private static byte[] readAndRelease(List<? extends DataBuffer> dataBuffers) {
        DataBuffer joined = new DefaultDataBufferFactory().join(dataBuffers);
        try {
            byte[] content = new byte[joined.readableByteCount()];
            joined.read(content);
            return content;
        } finally {
            DataBufferUtils.release(joined);
        }
    }

    /**
     * Parses {@code content} as a JSON object and, when it has a {@code servers} entry, prepends a
     * server entry pointing at {@code gatewayUrl}.
     *
     * @param content    the raw response body (interpreted as UTF-8)
     * @param gatewayUrl the URL to advertise as the first server
     * @return the re-serialized document, or the very same {@code content} instance when the body
     *         has no {@code servers} entry or cannot be processed (the failure is logged)
     */
    @SuppressWarnings("unchecked")
    private static byte[] addGatewayServer(byte[] content, String gatewayUrl) {
        try {
            String responseData = new String(content, StandardCharsets.UTF_8);
            Map<String, Object> document = JsonUtils.json2Object(responseData, Map.class);
            Object serversObject = (document == null) ? null : document.get("servers");
            if (serversObject == null) {
                return content;
            }

            List<Map<String, Object>> servers = (List<Map<String, Object>>) serversObject;
            Map<String, Object> gatewayServer = new HashMap<>();
            // Gateway address.
            gatewayServer.put("url", gatewayUrl);
            gatewayServer.put("description", "Gateway server url");
            // Put it first; the list is the one held by the document, so no re-put is needed.
            servers.add(0, gatewayServer);

            return JsonUtils.object2Json(document).getBytes(StandardCharsets.UTF_8);
        } catch (Exception e) {
            // Best effort: a body that cannot be rewritten is forwarded unchanged.
            log.error(e.getMessage(), e);
            return content;
        }
    }


    /**
     * Small JSON helper built around one shared, pre-configured {@link ObjectMapper}.
     */
    static class JsonUtils {

        public static final ObjectMapper OBJECT_MAPPER = createObjectMapper();

        private JsonUtils() {
            // Static utility holder; not meant to be instantiated.
        }

        /**
         * Creates and configures the shared {@link ObjectMapper}.
         *
         * @return a mapper that accepts unquoted field names, maps empty strings to null objects,
         *         ignores unknown properties, writes dates as timestamps and supports java.time
         */
        private static ObjectMapper createObjectMapper() {
            ObjectMapper objectMapper = new ObjectMapper();

            objectMapper.configure(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES, true);
            objectMapper.configure(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT, true);
            objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
            objectMapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, true);
            objectMapper.registerModule(new JavaTimeModule());

            return objectMapper;
        }

        /**
         * Serializes an object to its JSON string form.
         *
         * @param o the object to serialize
         * @return the JSON text
         * @throws RuntimeException if serialization fails (the cause is preserved)
         */
        public static String object2Json(Object o) {
            try {
                return OBJECT_MAPPER.writeValueAsString(o);
            } catch (IOException e) {
                throw new RuntimeException("不能序列化对象为Json", e);
            }
        }


        /**
         * Converts a JSON string to an object.
         *
         * @param json  the JSON text
         * @param clazz the target class
         * @return the deserialized instance
         * @throws RuntimeException if the text cannot be parsed or mapped (the cause is preserved)
         */
        public static <T> T json2Object(String json, Class<T> clazz) {
            try {
                return OBJECT_MAPPER.readValue(json, clazz);
            } catch (IOException e) {
                throw new RuntimeException("将 Json 转换为对象时异常,数据是:" + json, e);
            }
        }


    }
}
